from typing import Optional, Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call as argforward
from collections import OrderedDict

from copy import deepcopy

from neural_fields.data import CycloneNFDataset, CycloneNFDataLoader
from neural_fields.nf_train import integral_losses


def train_nf_maml(
    model: nn.Module,
    n_epochs: int,
    inner_lr: float,
    data: Sequence[CycloneNFDataset],
    loader: Sequence[CycloneNFDataLoader],
    device: torch.device,
    optim: torch.optim.Optimizer,
    use_flux_fields: bool = False,
    use_spectral: bool = False,
    cheat_integral: bool = False,
    sched: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
    use_print: bool = True,
):
    """Meta-train ``model`` with second-order MAML over a set of tasks.

    Every epoch, for each task loader, a functional copy of the trainable
    parameters takes one SGD step (step size ``inner_lr``) per batch.  The
    inner updates are built with ``create_graph=True``, so the meta-loss (the
    per-step MSE averaged over the batches of a task and then over the tasks)
    is differentiated back to the original parameters, and ``optim`` takes one
    outer step per epoch.  ``sched``, if given, is stepped after every outer
    step.

    After the outer step the model is evaluated on every entry of ``data``
    with ``integral_losses``.  Those values are only reported and are not used
    for model selection.

    Args:
        model: network, called functionally as ``model(coords)``.
        n_epochs: number of outer (meta) updates.
        inner_lr: step size of the inner SGD adaptation.
        data: per-task datasets used only for evaluation (may be empty).
        loader: per-task loaders yielding ``(f, coords)`` batches.
        device: device the model, loaders and datasets are moved to.
        optim: outer optimizer over the model parameters.
        use_flux_fields: forwarded to ``integral_losses``.
        use_spectral: forwarded to ``integral_losses``.
        cheat_integral: forwarded to ``integral_losses``.
        sched: optional LR scheduler stepped after every outer update.
        use_print: print the per-epoch losses.

    Returns:
        ``(model, best_model, {"meta_loss": best_loss})``.  ``model`` is the
        final model.  ``best_model`` is a copy holding the parameters at which
        the lowest meta-loss was measured.

    Raises:
        ValueError: if ``n_epochs > 0`` and ``loader`` is empty, or if its
            loaders yield no batches.
    """
    model.train()
    model.to(device)
    torch.set_float32_matmul_precision("high")
    best_loss = float("inf")
    best_model_state = deepcopy(model.state_dict())

    n_tasks = len(loader)
    if n_epochs > 0 and n_tasks == 0:
        raise ValueError("train_nf_maml needs at least one task loader")

    # Only trainable parameters are adapted and meta-differentiated.  Frozen
    # ones are absent from the dict, so functional_call takes them from the
    # module itself instead of making autograd.grad fail on them.
    meta_params = OrderedDict(
        (name, p) for name, p in model.named_parameters() if p.requires_grad
    )

    for e in range(n_epochs):
        # inner loop: differentiable per-task adaptation
        metaloss_total = 0.0
        for task_loader in loader:
            task_loader.to(device)
            n_batches = len(task_loader)
            inner_loss = 0.0
            inner_weights = OrderedDict(meta_params)
            for f, coords in task_loader:
                pred_f = argforward(model, inner_weights, (coords,))
                loss = F.mse_loss(pred_f, f)
                grads = torch.autograd.grad(
                    loss,
                    tuple(inner_weights.values()),
                    create_graph=True,  # keep the graph for the meta-gradient
                    allow_unused=True,
                )
                # Parameters that did not contribute to the loss keep their value.
                inner_weights = OrderedDict(
                    (name, param if grad is None else param - inner_lr * grad)
                    for (name, param), grad in zip(inner_weights.items(), grads)
                )
                inner_loss = inner_loss + loss / n_batches
            # Average over the tasks actually iterated (the loaders).
            metaloss_total = metaloss_total + inner_loss / n_tasks

        if not torch.is_tensor(metaloss_total):
            raise ValueError("train_nf_maml: the task loaders yielded no batches")
        metaloss = metaloss_total.item()

        # outer optim
        optim.zero_grad()
        metagrads = torch.autograd.grad(
            metaloss_total, tuple(meta_params.values()), allow_unused=True
        )
        for w, g in zip(meta_params.values(), metagrads):
            w.grad = g

        # metaloss was measured with the current (pre-step) parameters, so
        # snapshot them now; after optim.step() they no longer match it.
        if metaloss < best_loss:
            best_loss = metaloss
            best_model_state = deepcopy(model.state_dict())

        optim.step()
        if sched is not None:
            sched.step()

        # evaluation (reporting only)
        eval_losses = []
        for task_data in data:
            task_data.to(device)
            with torch.no_grad():
                t_eval = integral_losses(
                    model,
                    task_data,
                    device,
                    use_flux_fields=use_flux_fields,
                    use_spectral=use_spectral,
                    cheat_integral=cheat_integral,
                )
                eval_losses.append(t_eval)

        if eval_losses:
            mean_eval = {
                k: sum(v[k] for v in eval_losses) / len(eval_losses)
                for k in eval_losses[0]
            }
        else:
            mean_eval = {}
        losses = {"train/metaloss": metaloss}
        losses.update({f"val/{k}": v for k, v in mean_eval.items()})

        if use_print:
            str_losses = ", ".join(f"{k}: {float(v):.6f}" for k, v in losses.items())
            print(f"[{e}] {str_losses}")

    best_model = deepcopy(model)
    best_model.load_state_dict(best_model_state)

    return model, best_model, {"meta_loss": best_loss}


def _adapt_functa_latent(
    model: nn.Module,
    task_loader: CycloneNFDataLoader,
    inner_lr: float,
    n_inner_steps: int,
):
    """Fit a task-specific copy of ``model.z_functa`` with plain SGD.

    The copy is detached from ``model.z_functa`` and updated in place under
    ``torch.no_grad()``, so the adaptation is first-order: no gradient flows
    from later losses back through these updates.  Gradients are taken with
    ``torch.autograd.grad`` w.r.t. the latent only, so nothing accumulates in
    the network parameters' ``.grad`` and nothing carries over between steps.
    If ``task_loader`` runs out before ``n_inner_steps`` batches, a new pass
    over it is started.

    Returns:
        ``(latent, mean_loss)``.  ``latent`` is the adapted latent, a leaf
        tensor with ``requires_grad=True``.  ``mean_loss`` is the detached mean
        MSE over the inner steps, or ``None`` when ``n_inner_steps <= 0``.

    Raises:
        ValueError: if ``n_inner_steps > 0`` and ``task_loader`` yields no batches.
    """
    latent = model.z_functa.detach().clone().requires_grad_()
    if n_inner_steps <= 0:
        return latent, None

    running = 0.0
    batches = iter(task_loader)
    for _ in range(n_inner_steps):
        batch = next(batches, None)
        if batch is None:
            # Loader exhausted before n_inner_steps: start another pass.
            batches = iter(task_loader)
            batch = next(batches, None)
            if batch is None:
                raise ValueError("task loader yielded no batches for latent adaptation")
        f, coords = batch
        loss = F.mse_loss(model(coords, cond=latent), f)
        (grad,) = torch.autograd.grad(loss, latent, allow_unused=True)
        if grad is not None:
            with torch.no_grad():
                latent.add_(grad, alpha=-inner_lr)  # same update as torch.optim.SGD
        running = running + loss.detach()  # detached: the inner graph is not kept
    return latent, running / n_inner_steps


def train_nf_functa(
    model: nn.Module,
    n_epochs: int,
    inner_lr: float,
    n_inner_steps: int,
    data: Sequence[CycloneNFDataset],
    loader: Sequence[CycloneNFDataLoader],
    device: torch.device,
    optim: torch.optim.Optimizer,
    use_flux_fields: bool = False,
    use_spectral: bool = False,
    cheat_integral: bool = False,
    sched: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
    use_print: bool = True,
):
    """Train ``model`` functa-style: per-task latents plus a shared network.

    Every epoch, each task starts from a detached copy of ``model.z_functa``
    and takes ``n_inner_steps`` first-order SGD steps (step size ``inner_lr``)
    on that latent only.  The network is then trained with one ``optim``
    step on the MSE over every batch of every task, conditioned on the
    task's adapted latent.  Because the adapted copies are detached from
    ``model.z_functa``, no gradient reaches it through them.

    The meta-loss used for model selection is the mean inner-loop loss over
    adaptation steps and tasks, measured before the outer step.  With
    ``n_inner_steps <= 0`` no adaptation happens: every task uses the
    unadapted latent, and the meta-loss falls back to the outer loss.

    Args:
        model: network called as ``model(coords, cond=latent)``; must expose
            ``z_functa``.
        n_epochs: number of outer updates.
        inner_lr: step size of the inner latent SGD.
        n_inner_steps: number of latent SGD steps per task and epoch.
        data: per-task datasets used only for evaluation (may be empty).
        loader: per-task loaders yielding ``(f, coords)`` batches.
        device: device the model, loaders and datasets are moved to.
        optim: outer optimizer over the network parameters.
        use_flux_fields: forwarded to ``integral_losses``.
        use_spectral: forwarded to ``integral_losses``.
        cheat_integral: forwarded to ``integral_losses``.
        sched: optional LR scheduler stepped after every outer update.
        use_print: print the per-epoch losses.

    Returns:
        ``((model, z_functa_tasks), best_model, {"meta_loss": best_loss})``.
        ``z_functa_tasks`` holds the adapted latents from the last epoch
        (empty if ``n_epochs == 0``).  ``best_model`` holds the parameters at
        which the lowest meta-loss was measured.

    Raises:
        ValueError: if ``n_epochs > 0`` and ``loader`` is empty, or if its
            loaders yield no batches.
    """
    model.train()
    model.to(device)
    torch.set_float32_matmul_precision("high")
    best_loss = float("inf")
    best_model_state = deepcopy(model.state_dict())

    n_tasks = len(loader)
    if n_epochs > 0 and n_tasks == 0:
        raise ValueError("train_nf_functa needs at least one task loader")

    # Defined up front so the return value exists even when n_epochs == 0.
    z_functa_tasks = []

    for e in range(n_epochs):
        # inner loop: first-order adaptation of one latent per task
        z_functa_tasks = []
        inner_losses = []
        for task_loader in loader:
            task_loader.to(device)
            task_z, task_loss = _adapt_functa_latent(
                model, task_loader, inner_lr, n_inner_steps
            )
            z_functa_tasks.append(task_z)
            if task_loss is not None:
                inner_losses.append(task_loss)

        # outer optim: fit the network to every task given its adapted latent
        outer_loss = 0.0
        for task_cond, task_loader in zip(z_functa_tasks, loader):
            n_batches = len(task_loader)
            for f, coords in task_loader:
                pred_f = model(coords, cond=task_cond)
                outer_loss = outer_loss + F.mse_loss(pred_f, f) / n_batches
        if not torch.is_tensor(outer_loss):
            raise ValueError("train_nf_functa: the task loaders yielded no batches")
        outer_loss = outer_loss / n_tasks
        optim.zero_grad()
        outer_loss.backward()
        outer_loss_value = outer_loss.item()

        if inner_losses:
            metaloss = (sum(inner_losses) / n_tasks).item()
        else:
            # No adaptation steps: the outer loss is the loss of the unadapted latent.
            metaloss = outer_loss_value

        # metaloss was measured with the current (pre-step) parameters, so
        # snapshot them now; after optim.step() they no longer match it.
        if metaloss < best_loss:
            best_loss = metaloss
            best_model_state = deepcopy(model.state_dict())

        optim.step()
        if sched is not None:
            sched.step()

        # evaluation (reporting only)
        eval_losses = []
        for task_data in data:
            task_data.to(device)
            with torch.no_grad():
                t_eval = integral_losses(
                    model,
                    task_data,
                    device,
                    use_flux_fields=use_flux_fields,
                    use_spectral=use_spectral,
                    cheat_integral=cheat_integral,
                )
                eval_losses.append(t_eval)

        if eval_losses:
            mean_eval = {
                k: sum(v[k] for v in eval_losses) / len(eval_losses)
                for k in eval_losses[0]
            }
        else:
            mean_eval = {}
        losses = {"train/metaloss": metaloss, "train/outer_loss": outer_loss_value}
        losses.update({f"val/{k}": v for k, v in mean_eval.items()})

        if use_print:
            str_losses = ", ".join(f"{k}: {float(v):.6f}" for k, v in losses.items())
            print(f"[{e}] {str_losses}")

    best_model = deepcopy(model)
    best_model.load_state_dict(best_model_state)

    return (model, z_functa_tasks), best_model, {"meta_loss": best_loss}
